###########################################################
#### figure_s3.R                                       ####
#### Generates Figure S3 and accompanying descriptives ####
###########################################################

# Variable metadata; rows with `is_fact == 1` flag categorical variables.
tab_in = import(here("Data", "var_data.csv"))

# Columns to convert to factors: each flagged variable (listed once, in
# first-seen order), followed by the state and county identifiers.
vars_fact = c(
  tab_in |>
    filter(is_fact == 1) |>
    distinct(var) |>
    pull(var),
  "state", "COUNTYFIP"
)

# Modelling frame: the outcome, every candidate predictor, and the
# geographic identifiers. Factor conversion happens before any rows are
# dropped, so factor levels are taken from the full `data`; afterwards only
# complete cases (including a non-missing outcome) are kept.
data_rf_cat = data |>
  select(
    affpol_bin,
    urbanicity2, bidenvote_cnty, efi_cnty,
    bidenvote, efi_state, divided, professional,
    mds1, mds2, h_diffs, index,
    age, gender, race_cat, pid, pew_bornagain,
    gdp_per_cap, unemployment,
    general_trust, institutional_corruption, institutional_response,
    vote_importance, pride, fair_treatment, strong, newsint,
    state, COUNTYFIP
  ) |>
  mutate(across(all_of(vars_fact), as.factor)) |>
  filter(!is.na(affpol_bin)) |>
  drop_na()

# Reproducible 80/20 train/test split of the complete-case modelling frame.
set.seed(120)
data_split_cat = data_rf_cat |> initial_split(prop = .8)
data_train_cat = data_split_cat |> training()
data_test_cat  = data_split_cat |> testing()

# Model formulas, one per predictor set; each is handed to recipe() in the
# fitting loop. They are assembled from shared term groups. The order in
# which groups are concatenated fixes the predictor order in each formula,
# so it is kept stable on purpose.
terms_cnty  = c("urbanicity2", "bidenvote_cnty", "efi_cnty")
terms_state = c("bidenvote", "efi_state", "divided", "professional",
                "mds1", "mds2", "h_diffs", "index")
terms_indiv = c("age", "gender", "race_cat", "pid", "pew_bornagain")
terms_econ  = c("gdp_per_cap", "unemployment")
terms_psych = c("pid", "general_trust", "institutional_corruption",
                "institutional_response", "vote_importance", "pride",
                "fair_treatment", "strong", "newsint")

fc_indiv_st  = reformulate(c(terms_state, terms_indiv, terms_econ),
                           response = "affpol_bin")
fc_indiv_geo = reformulate(c(terms_cnty, terms_state, terms_indiv, terms_econ),
                           response = "affpol_bin")
fc_indiv     = reformulate(terms_indiv, response = "affpol_bin")
fc_st        = reformulate(c(terms_state, terms_econ), response = "affpol_bin")
fc_geo       = reformulate(c(terms_cnty, terms_state, terms_econ),
                           response = "affpol_bin")
fc_psych     = reformulate(terms_psych, response = "affpol_bin")

# Formulas and their display labels are matched by position.
fc_list = list(fc_indiv_st, fc_indiv_geo, fc_indiv, fc_st, fc_geo, fc_psych)
set_vars = c("Individual + State", "Everything (Non-Psych)", "Individual Only",
             "State Only", "All Geographic", "Psychological")

# Cross-validation folds for tuning, built once on the training split and
# shared by every predictor set so their tuning runs see identical resamples.
set.seed(101)
folds = vfold_cv(data_train_cat)

# Formulas and labels are paired by position; fail loudly if they drift apart.
stopifnot(length(fc_list) == length(set_vars))

# For each predictor set:
#   1. tune a ranger random forest (mtry, min_n) over `folds`,
#   2. keep the candidate with the best cross-validated ROC AUC,
#   3. refit it on the training split and evaluate once on the held-out test
#      split with last_fit().
# Result: the test-set metrics from collect_metrics(), row-bound across sets
# and labelled by the `set` column.
rf_output_cat = map_dfr(seq_along(fc_list), \(x){
  print(set_vars[x])
  # Preprocessing: the predictor-set formula only, no extra steps
  rec_rf = recipe(fc_list[[x]], data = data_train_cat)
  # Random forest with tunable mtry / min_n and a fixed 1000 trees; `seed` is
  # passed through to ranger so the forests themselves are reproducible.
  rf_model = rand_forest(mtry = tune(),
                         min_n = tune(),
                         trees = 1000) |> 
    set_engine("ranger",
               num.threads = parallel::detectCores(),
               importance = "impurity",
               seed = 613) |> 
    set_mode("classification")
  
  rf_wflow = workflow() |> 
    add_model(rf_model) |> 
    add_recipe(rec_rf)
  
  # Tune over 10 automatically generated candidate parameter combinations
  set.seed(345)
  tune_res = tune_grid(
    rf_wflow,
    resamples = folds,
    grid = 10
  )
  # Finalize: `metric` must be named. In current tune the signature is
  # select_best(x, ..., metric = NULL), so a positional "roc_auc" would be
  # swallowed by `...` rather than used as the selection metric.
  best_roc = select_best(tune_res, metric = "roc_auc")
  final_rf = finalize_model(rf_model, best_roc)
  final_wf = workflow() |> 
    add_model(final_rf) |> 
    add_recipe(rec_rf)
  # Fit on the training portion of the split, evaluate on its test portion
  final_res = final_wf |> 
    last_fit(data_split_cat)
  # Test-set metrics for this predictor set
  collect_metrics(final_res) |> 
    mutate(set = set_vars[x])
})

# Persist the per-set test metrics, then show the outcome class balance in
# the training split (auto-printed when run at top level).
export(rf_output_cat, file = here("Output", "rf_output_cat.rds"))
data_train_cat$affpol_bin |> table() |> prop.table()

# Null (majority-class) baseline, scored the same way as the forests: always
# predict the most frequent training-set class and measure accuracy on the
# held-out test split. All bars in the figure are then test-set accuracies;
# previously the baseline was the training-set majority proportion.
majority_class = names(which.max(table(data_train_cat$affpol_bin)))
null_accuracy = mean(as.character(data_test_cat$affpol_bin) == majority_class)

# Figure S3: test-set accuracy per predictor set, plus the null baseline,
# ordered by accuracy; each bar is labelled with its value (3 d.p.).
fig_s3 = rf_output_cat |> 
  add_case(.metric = "accuracy", set = "Null", .estimate = null_accuracy) |> 
  filter(.metric == "accuracy") |> 
  ggplot(aes(x = .estimate, y = reorder(set, .estimate))) +
  geom_col() +
  # Place labels slightly inside the end of each bar
  geom_label(aes(x = .estimate - .03, label = round(.estimate, 3))) +
  labs(x = "Accuracy", y = "Predictor Set") +
  theme_prl()

print(fig_s3)

ggsave(here("Plots","figure_s3.pdf"), fig_s3,
       dpi = 600, units = "in", height = 4, width = 6)
